Chapter 5: Datasets and models

Load packages, fit the random forest model and define the example passengers

# Attach the packages used below: DALEX (explainers and the Titanic data),
# randomForest (the model), patchwork (plot composition) and ggplot2
# (scale adjustments on the break-down plots).
library(DALEX)
library(randomForest)
library(patchwork)
library(ggplot2)
set.seed(1313)

# Random forest on the imputed Titanic data using seven explanatory
# variables. The explainer output later reports "task regression", i.e.
# `survived` is treated as a numeric response here.
titanic_rf <- randomForest(
  survived ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic_imputed
)

# Henry: a 47-year-old man in 1st class who embarked at Cherbourg and paid
# a fare of 25, travelling without siblings/spouses or parents/children.
# Factor levels are spelled out in full, presumably so they match the
# levels of the corresponding columns in `titanic_imputed` (TODO confirm).
henry <- local({
  class_levels <- c("1st", "2nd", "3rd", "deck crew", "engineering crew",
                    "restaurant staff", "victualling crew")
  port_levels <- c("Belfast", "Cherbourg", "Queenstown", "Southampton")
  data.frame(
    class = factor("1st", levels = class_levels),
    gender = factor("male", levels = c("female", "male")),
    age = 47,
    sibsp = 0,
    parch = 0,
    fare = 25,
    embarked = factor("Cherbourg", levels = port_levels)
  )
})
henry
# Johnny D: an 8-year-old boy in 1st class who embarked at Southampton and
# paid a fare of 72, travelling without siblings/spouses or
# parents/children. Factor levels are listed in full, presumably to match
# the levels in `titanic_imputed` (TODO confirm).
johnny_d <- local({
  class_levels <- c("1st", "2nd", "3rd", "deck crew", "engineering crew",
                    "restaurant staff", "victualling crew")
  port_levels <- c("Belfast", "Cherbourg", "Queenstown", "Southampton")
  data.frame(
    class = factor("1st", levels = class_levels),
    gender = factor("male", levels = c("female", "male")),
    age = 8,
    sibsp = 0,
    parch = 0,
    fare = 72,
    embarked = factor("Southampton", levels = port_levels)
  )
})
johnny_d
# Wrap the random forest in a DALEX explainer.
# The target column is removed from `data` by name, not by position: R
# silently ignores negative out-of-range indices, so a positional `[, -9]`
# removes nothing when `survived` is not the 9th column and leaves the
# target among the explanatory data.
titanic_rf_exp <- DALEX::explain(
  model = titanic_rf,
  data = titanic_imputed[, setdiff(colnames(titanic_imputed), "survived")],
  y = titanic_imputed$survived,
  label = "Random Forest"
)
# NOTE(review): the output below was captured with the old positional
# `[, -9]` selection; re-render to refresh the reported `data` column count.
## Preparation of a new explainer is initiated
##   -> model label       :  Random Forest 
##   -> data              :  2207  rows  8  cols 
##   -> target variable   :  2207  values 
##   -> predict function  :  yhat.randomForest  will be used (  default  )
##   -> predicted values  :  numerical, min =  0.01590278 , mean =  0.3222722 , max =  0.9900173  
##   -> model_info        :  package randomForest , ver. 4.6.14 , task regression (  default  ) 
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.7970723 , mean =  -0.0001153935 , max =  0.8992474  
##   A new explainer has been created! 
# Override the detected task ("regression" above) and declare the model
# as a classifier.
titanic_rf_exp$model_info$type <- "classification"

Chapter 7: Break-down Plots for Additive Attributions

Examples

# Break-down attributions for Johnny D under a fixed variable order,
# keeping the distributions so they can be drawn with
# plot_distributions = TRUE.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = johnny_d,
  type = "break_down",
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked"),
  keep_distributions = TRUE
)
bd_rf

Plot the break-down plots

# Plain break-down plot first, then the version showing the kept
# distributions.
plot(x = bd_rf)

plot(x = bd_rf, plot_distributions = TRUE)

Basic use of the predict_parts() function

# Break-down attributions for Henry with all other settings left at
# their defaults.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down"
)
bd_rf

Plot the break-down plots

# Break-down plot for Henry.
plot(x = bd_rf)

Advanced use of the predict_parts() function

# Henry's break-down with an explicitly chosen variable order.
bd_rf_order <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down",
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked")
)
# Restrict the plot to at most three variables.
plot(bd_rf_order, max_features = 3)

# Same ordered break-down for Henry, this time keeping the distributions
# so they can be drawn.
bd_rf_distr <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down",
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked"),
  keep_distributions = TRUE
)
plot(bd_rf_distr, plot_distributions = TRUE)

Chapter 8: Break-down Plots for Interactions (iBreak-down Plots)

Examples

# Break-down with interactions (iBreak-down) for Johnny D.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = johnny_d,
  type = "break_down_interactions"
)
bd_rf
plot(x = bd_rf)

Code snippets for R

# Break-down with interactions (iBreak-down) for Henry.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down_interactions"
)
bd_rf
plot(x = bd_rf)

Chapter 9: Shapley Additive Explanations (SHAP) and Average Variable Attributions

# Ten break-down plots for Johnny D, each computed under a different
# random variable order, laid out in a two-column grid.
set.seed(13)

rsample <- lapply(seq_len(10), function(i) {
  # Random permutation of the variable positions. Its length is taken from
  # johnny_d instead of a hard-coded 7; sample(seq_len(k)) makes the same
  # RNG draws as sample(1:k), so results are unchanged.
  new_order <- sample(seq_len(ncol(johnny_d)))
  bd <- predict_parts(titanic_rf_exp, johnny_d, order = new_order)
  # Shorten the long port label so it fits into the small panels.
  bd$variable <- as.character(bd$variable)
  bd$variable[bd$variable == "embarked = Southampton"] <- "embarked = S"
  bd$label <- paste("random order no.", i)
  plot(bd) +
    scale_y_continuous(limits = c(0.1, 0.6), name = "",
                       breaks = seq(0.1, 0.6, 0.1))
})

# Reduce() folds `+` left to right, i.e. rsample[[1]] + ... + rsample[[10]],
# so the patchwork is the same as writing out all ten terms by hand.
Reduce(`+`, rsample) + plot_layout(ncol = 2)

# SHAP attributions for Johnny D; B = 25 sets how many random variable
# orderings are sampled.
shap_johnny <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = johnny_d,
  type = "shap",
  B = 25
)

Example: Titanic data

Code snippets for R

# Model prediction for Henry (captured output follows).
predict(titanic_rf_exp, henry)
##         1 
## 0.3081968
# SHAP attributions for Henry from B = 25 random variable orderings.
shap_henry <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "shap",
  B = 25
)
shap_henry
plot(x = shap_henry)

# SHAP plot for Henry without the boxplots.
plot(x = shap_henry, show_boxplots = FALSE)

Session info

# Record the R version and package versions used for the output above.
utils::sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.3
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] ggplot2_3.3.2       patchwork_1.0.1     randomForest_4.6-14
## [4] DALEX_1.3.1        
## 
## loaded via a namespace (and not attached):
##  [1] pillar_1.4.6     compiler_4.0.2   tools_4.0.2      digest_0.6.25   
##  [5] jsonlite_1.7.0   evaluate_0.14    lifecycle_0.2.0  tibble_3.0.3    
##  [9] gtable_0.3.0     pkgconfig_2.0.3  png_0.1-7        rlang_0.4.7     
## [13] yaml_2.2.1       xfun_0.15        withr_2.2.0      stringr_1.4.0   
## [17] dplyr_1.0.0      knitr_1.29       generics_0.0.2   vctrs_0.3.2     
## [21] grid_4.0.2       tidyselect_1.1.0 glue_1.4.1       R6_2.4.1        
## [25] rmarkdown_2.3    iBreakDown_1.3.0 farver_2.0.3     purrr_0.3.4     
## [29] magrittr_1.5     scales_1.1.1     ellipsis_0.3.1   htmltools_0.5.0 
## [33] colorspace_1.4-1 labeling_0.3     stringi_1.4.6    munsell_0.5.0   
## [37] crayon_1.3.4